{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- _c0: string (nullable = true)\n",
      " |-- _c1: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.appName('nlp').getOrCreate()\n",
    "df = spark.read.csv('SMSSpamCollection', inferSchema = True, sep = '\\t')\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|class|                text|\n",
      "+-----+--------------------+\n",
      "|  ham|Go until jurong p...|\n",
      "|  ham|Ok lar... Joking ...|\n",
      "| spam|Free entry in 2 a...|\n",
      "|  ham|U dun say so earl...|\n",
      "|  ham|Nah I don't think...|\n",
      "| spam|FreeMsg Hey there...|\n",
      "|  ham|Even my brother i...|\n",
      "|  ham|As per your reque...|\n",
      "| spam|WINNER!! As a val...|\n",
      "| spam|Had your mobile 1...|\n",
      "|  ham|I'm gonna be home...|\n",
      "| spam|SIX chances to wi...|\n",
      "| spam|URGENT! You have ...|\n",
      "|  ham|I've been searchi...|\n",
      "|  ham|I HAVE A DATE ON ...|\n",
      "| spam|XXXMobileMovieClu...|\n",
      "|  ham|Oh k...i'm watchi...|\n",
      "|  ham|Eh u remember how...|\n",
      "|  ham|Fine if thats th...|\n",
      "| spam|England v Macedon...|\n",
      "+-----+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df = df.withColumnRenamed('_c0', 'class').withColumnRenamed('_c1', 'text')\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+------+\n",
      "|class|                text|length|\n",
      "+-----+--------------------+------+\n",
      "|  ham|Go until jurong p...|   111|\n",
      "|  ham|Ok lar... Joking ...|    29|\n",
      "| spam|Free entry in 2 a...|   155|\n",
      "+-----+--------------------+------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.functions import length\n",
    "\n",
    "df = df.withColumn('length', length(df['text']))\n",
    "df.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----------------+\n",
      "|class|      avg(length)|\n",
      "+-----+-----------------+\n",
      "|  ham|71.45431945307645|\n",
      "| spam|138.6706827309237|\n",
      "+-----+-----------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('class').mean().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import (CountVectorizer, Tokenizer, \n",
    "                                StopWordsRemover, IDF, StringIndexer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer(inputCol = 'text', outputCol = 'token_text')\n",
    "stop_remove = StopWordsRemover(inputCol = 'token_text', outputCol = 'stop_token')\n",
    "count_vec = CountVectorizer(inputCol = 'stop_token', outputCol = 'c_vec')\n",
    "idf = IDF(inputCol = 'c_vec', outputCol = 'tf_idf')\n",
    "ham_spam_to_numeric = StringIndexer(inputCol = 'class', outputCol = 'label')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "clean_up = VectorAssembler(inputCols = ['tf_idf', 'length'], outputCol = 'features')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import NaiveBayes\n",
    "\n",
    "nb = NaiveBayes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "\n",
    "pipeline = Pipeline(stages=[ham_spam_to_numeric, tokenizer, stop_remove, count_vec, idf, clean_up])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|label|            features|\n",
      "+-----+--------------------+\n",
      "|  0.0|(13424,[7,11,31,6...|\n",
      "|  0.0|(13424,[0,24,297,...|\n",
      "|  1.0|(13424,[2,13,19,3...|\n",
      "+-----+--------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "cleaner = pipeline.fit(df)\n",
    "clean_df = cleaner.transform(df)\n",
    "clean_df = clean_df.select('label', 'features')\n",
    "clean_df.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = clean_df.randomSplit([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- class: string (nullable = true)\n",
      " |-- text: string (nullable = true)\n",
      " |-- length: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "|label|            features|       rawPrediction|         probability|prediction|\n",
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "|  0.0|(13424,[0,1,2,41,...|[-1061.8164592374...|[1.0,1.5335777452...|       0.0|\n",
      "|  0.0|(13424,[0,1,5,20,...|[-815.00170676097...|[1.0,3.5993277744...|       0.0|\n",
      "|  0.0|(13424,[0,1,9,14,...|[-539.27732500216...|[1.0,6.4723637628...|       0.0|\n",
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "spam_detector = nb.fit(train)\n",
    "predictions = spam_detector.transform(test)\n",
    "predictions.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 0.9149191132414619\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "\n",
    "evaluator = MulticlassClassificationEvaluator()\n",
    "print(\"Test Accuracy: \" + str(evaluator.evaluate(predictions, {evaluator.metricName: \"accuracy\"})))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
